{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dZXh7dogJlHH"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/jeffheaton/app_deep_learning/blob/main/t81_558_class_03_4_early_stop.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mKH1QxMuJlHK"
   },
   "source": [
    "# T81-558: Applications of Deep Neural Networks\n",
    "\n",
    "**Module 3: Introduction to PyTorch**\n",
    "\n",
    "- Instructor: [Jeff Heaton](https://sites.wustl.edu/jeffheaton/), McKelvey School of Engineering, [Washington University in St. Louis](https://engineering.wustl.edu/Programs/Pages/default.aspx)\n",
    "- For more information visit the [class website](https://sites.wustl.edu/jeffheaton/t81-558/).\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fwkbs9-gJlHL"
   },
   "source": [
    "# Module 3 Material\n",
    "\n",
    "- Part 3.1: Deep Learning and Neural Network Introduction [[Video]](https://www.youtube.com/watch?v=d-rU5IuFqLs&list=PLjy4p-07OYzuy_lHcRW8lPTLPTTOmUpmi) [[Notebook]](t81_558_class_03_1_neural_net.ipynb)\n",
    "- Part 3.2: Introduction to PyTorch [[Video]](https://www.youtube.com/watch?v=Pf-rrhMolm0&list=PLjy4p-07OYzuy_lHcRW8lPTLPTTOmUpmi) [[Notebook]](t81_558_class_03_2_pytorch.ipynb)\n",
    "- Part 3.3: Encoding a Feature Vector for PyTorch Deep Learning [[Video]](https://www.youtube.com/watch?v=7SGPm2tIT58&list=PLjy4p-07OYzuy_lHcRW8lPTLPTTOmUpmi) [[Notebook]](t81_558_class_03_3_feature_encode.ipynb)\n",
    "- **Part 3.4: Early Stopping and Network Persistence** [[Video]](https://www.youtube.com/watch?v=lS0vvIWiahU&list=PLjy4p-07OYzuy_lHcRW8lPTLPTTOmUpmi) [[Notebook]](t81_558_class_03_4_early_stop.ipynb)\n",
    "- Part 3.5: Sequences vs Classes in PyTorch [[Video]](https://www.youtube.com/watch?v=NOu8jMZx3LY&list=PLjy4p-07OYzuy_lHcRW8lPTLPTTOmUpmi) [[Notebook]](t81_558_class_03_5_pytorch_class_sequence.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ovYF1H1ZJlHL"
   },
   "source": [
    "# Google CoLab Instructions\n",
    "\n",
    "The following code ensures that Google CoLab is running and maps Google Drive if needed. We also initialize the PyTorch device to either GPU/MPS (if available) or CPU.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "H4wO3BiMJlHM",
    "outputId": "4eca8b22-802a-47c9-eee3-365fad896474"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Note: not using Google CoLab\n",
      "Using device: mps\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "try:\n",
    "    import google.colab\n",
    "\n",
    "    COLAB = True\n",
    "    print(\"Note: using Google CoLab\")\n",
    "except ImportError:\n",
    "    print(\"Note: not using Google CoLab\")\n",
    "    COLAB = False\n",
    "\n",
    "# Make use of a GPU or MPS (Apple) if one is available.  (see module 3.2)\n",
    "device = (\n",
    "    \"mps\"\n",
    "    if torch.backends.mps.is_available()\n",
    "    else \"cuda\"\n",
    "    if torch.cuda.is_available()\n",
    "    else \"cpu\"\n",
    ")\n",
    "print(f\"Using device: {device}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mWo4ptCdJlHN"
   },
   "source": [
    "# Part 3.4: Early Stopping and Network Persistence\n",
    "\n",
    "This part covers two essential aspects of training neural networks: early stopping and saving/loading PyTorch networks. Early stopping helps prevent overfitting by monitoring the validation loss during training and halting once that loss stops improving, which also avoids unnecessary iterations and saves computational resources. We will also explore how to save and load PyTorch networks so that trained models can be stored and reused for predictions or further training.\n",
    "\n",
    "## Early Stopping\n",
    "\n",
    "It can be difficult to determine how many epochs to cycle through to train a neural network. Overfitting will occur if you train the neural network for too many epochs, and the neural network will not perform well on new data, despite attaining a good accuracy on the training set. Overfitting occurs when a neural network is trained to the point that it begins to memorize rather than generalize, as demonstrated in Figure 3.OVER.\n",
    "\n",
    "**Figure 3.OVER: Training vs. Validation Error for Overfitting**\n",
    "![Training vs. Validation Error for Overfitting](https://raw.githubusercontent.com/jeffheaton/t81_558_deep_learning/master/images/class_3_training_val.png \"Training vs. Validation Error for Overfitting\")\n",
    "\n",
    "It is important to segment the original dataset into several datasets:\n",
    "\n",
    "- **Training Set**\n",
    "- **Validation Set**\n",
    "- **Holdout Set**\n",
    "\n",
    "You can construct these sets in several different ways. The following programs demonstrate some of these.\n",
    "\n",
    "The first method is a training and validation set. We use the training data to train the neural network until the validation set no longer improves. This attempts to stop at a near-optimal training point. This method will only give accurate \"out of sample\" predictions for the validation set; this is usually 20% of the data. The predictions for the training data will be overly optimistic, as these were the data that we used to train the neural network. Figure 3.VAL demonstrates how we divide the dataset.\n",
    "\n",
    "**Figure 3.VAL: Training with a Validation Set**\n",
    "![Training with a Validation Set](https://raw.githubusercontent.com/jeffheaton/t81_558_deep_learning/master/images/class_1_train_val.png \"Training with a Validation Set\")\n"
   ]
  },
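  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of the three-way split listed above, the following cell divides a small synthetic dataset into training, validation, and holdout sets by calling scikit-learn's `train_test_split` twice. The synthetic data, the 60/20/20 proportions, and the `*_demo`, `*_tr`, `*_val`, and `*_holdout` variable names are assumptions chosen for illustration; the examples in the rest of this part use a simpler training/validation split.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Synthetic data for illustration only: 100 rows, 4 features, 3 classes.\n",
    "rng = np.random.default_rng(42)\n",
    "x_demo = rng.normal(size=(100, 4))\n",
    "y_demo = rng.integers(0, 3, size=100)\n",
    "\n",
    "# First, set aside 20% of the rows as the holdout set.\n",
    "x_rest, x_holdout, y_rest, y_holdout = train_test_split(\n",
    "    x_demo, y_demo, test_size=0.20, random_state=42\n",
    ")\n",
    "\n",
    "# Then split the remaining 80% into training (75% of it) and\n",
    "# validation (25% of it), i.e. roughly 60% / 20% / 20% overall.\n",
    "x_tr, x_val, y_tr, y_val = train_test_split(\n",
    "    x_rest, y_rest, test_size=0.25, random_state=42\n",
    ")\n",
    "\n",
    "print(f\"train: {len(x_tr)}, validation: {len(x_val)}, holdout: {len(x_holdout)}\")"
   ]
  },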
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qfQxxVVs--7K"
   },
   "source": [
    "Because PyTorch does not include a built-in early stopping function, we must define one of our own. We will use the following **EarlyStopping** class throughout this course.\n",
    "\n",
    "We can provide several parameters to the **EarlyStopping** object:\n",
    "\n",
    "- **min_delta** The minimum decrease in validation loss that counts as an improvement. Keep this value small. With the default of 0, any loss that is no worse than the best so far counts as an improvement.\n",
    "- **patience** The number of consecutive epochs without improvement that training will tolerate before it stops.\n",
    "- **restore_best_weights** You should usually set this to true, as it restores the weights to the values they had when the validation loss was at its lowest.\n",
    "\n",
    "We will now see an example of this class in action.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "CAezCpVfOFAF"
   },
   "outputs": [],
   "source": [
    "import copy\n",
    "\n",
    "\n",
    "class EarlyStopping:\n",
    "    def __init__(self, patience=5, min_delta=0, restore_best_weights=True):\n",
    "        self.patience = patience\n",
    "        self.min_delta = min_delta\n",
    "        self.restore_best_weights = restore_best_weights\n",
    "        self.best_model = None\n",
    "        self.best_loss = None\n",
    "        self.counter = 0\n",
    "        self.status = \"\"\n",
    "\n",
    "    def __call__(self, model, val_loss):\n",
    "        if self.best_loss is None:\n",
    "            self.best_loss = val_loss\n",
    "            self.best_model = copy.deepcopy(model.state_dict())\n",
    "        elif self.best_loss - val_loss >= self.min_delta:\n",
    "            self.best_model = copy.deepcopy(model.state_dict())\n",
    "            self.best_loss = val_loss\n",
    "            self.counter = 0\n",
    "            self.status = f\"Improvement found, counter reset to {self.counter}\"\n",
    "        else:\n",
    "            self.counter += 1\n",
    "            self.status = f\"No improvement in the last {self.counter} epochs\"\n",
    "            if self.counter >= self.patience:\n",
    "                self.status = f\"Early stopping triggered after {self.counter} epochs.\"\n",
    "                if self.restore_best_weights:\n",
    "                    model.load_state_dict(self.best_model)\n",
    "                return True\n",
    "        return False"
   ]
  },
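  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see how **patience** and **min_delta** interact, the following sketch feeds a hand-written sequence of validation losses to the **EarlyStopping** class defined above. The tiny `nn.Linear` model, the loss values, and the `demo_*` names are made up for illustration; only the class itself comes from this module. With these assumed values, the counter should reach the patience limit on the sixth loss, so the loop stops before the last value is used.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "# Hypothetical stand-in model; its weights never change in this demo.\n",
    "demo_model = nn.Linear(2, 1)\n",
    "\n",
    "# Made-up validation losses: three improving epochs, then a plateau.\n",
    "demo_losses = [1.00, 0.80, 0.70, 0.71, 0.72, 0.70, 0.73]\n",
    "\n",
    "# Stop after 3 epochs in a row that fail to beat the best loss by 0.01.\n",
    "demo_es = EarlyStopping(patience=3, min_delta=0.01)\n",
    "\n",
    "for demo_epoch, demo_vloss in enumerate(demo_losses, start=1):\n",
    "    stop = demo_es(demo_model, demo_vloss)\n",
    "    print(f\"Epoch {demo_epoch}: vloss={demo_vloss:.2f}, {demo_es.status}\")\n",
    "    if stop:\n",
    "        break"
   ]
  },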
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "nrsobz8ZJlHO"
   },
   "source": [
    "### Early Stopping with Classification\n",
    "\n",
    "We will now see an example of classification training with early stopping. We will train the neural network until the error no longer improves on the validation set.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch: 1, tloss: 0.6026307344436646, vloss: 0.536555, : 100%|██████████| 7/7 [00:00<00:00, 22.81it/s]\n",
      "Epoch: 2, tloss: 0.3658648133277893, vloss: 0.277725, Improvement found, counter reset to 0: 100%|██████████| 7/7 [00:00<00:00, 852.72it/s]\n",
      "Epoch: 3, tloss: 0.15603026747703552, vloss: 0.187535, Improvement found, counter reset to 0: 100%|██████████| 7/7 [00:00<00:00, 944.27it/s]\n",
      "Epoch: 4, tloss: 0.05794893577694893, vloss: 0.154333, Improvement found, counter reset to 0: 100%|██████████| 7/7 [00:00<00:00, 1005.17it/s]\n",
      "Epoch: 5, tloss: 0.18528978526592255, vloss: 0.076723, Improvement found, counter reset to 0: 100%|██████████| 7/7 [00:00<00:00, 1021.47it/s]\n",
      "Epoch: 6, tloss: 0.1242004930973053, vloss: 0.061499, Improvement found, counter reset to 0: 100%|██████████| 7/7 [00:00<00:00, 458.05it/s]\n",
      "Epoch: 7, tloss: 0.03340418264269829, vloss: 0.045322, Improvement found, counter reset to 0: 100%|██████████| 7/7 [00:00<00:00, 333.21it/s]\n",
      "Epoch: 8, tloss: 0.09452518075704575, vloss: 0.032975, Improvement found, counter reset to 0: 100%|██████████| 7/7 [00:00<00:00, 464.86it/s]\n",
      "Epoch: 9, tloss: 0.005208526272326708, vloss: 0.023963, Improvement found, counter reset to 0: 100%|██████████| 7/7 [00:00<00:00, 881.08it/s]\n",
      "Epoch: 10, tloss: 0.06230872869491577, vloss: 0.015515, Improvement found, counter reset to 0: 100%|██████████| 7/7 [00:00<00:00, 760.03it/s]\n",
      "Epoch: 11, tloss: 0.08908900618553162, vloss: 0.038120, No improvement in the last 1 epochs: 100%|██████████| 7/7 [00:00<00:00, 291.66it/s]\n",
      "Epoch: 12, tloss: 0.03496554493904114, vloss: 0.026789, No improvement in the last 2 epochs: 100%|██████████| 7/7 [00:00<00:00, 729.70it/s]\n",
      "Epoch: 13, tloss: 0.06976647675037384, vloss: 0.018425, No improvement in the last 3 epochs: 100%|██████████| 7/7 [00:00<00:00, 670.11it/s]\n",
      "Epoch: 14, tloss: 0.013938345946371555, vloss: 0.010584, Improvement found, counter reset to 0: 100%|██████████| 7/7 [00:00<00:00, 940.25it/s]\n",
      "Epoch: 15, tloss: 0.03223254159092903, vloss: 0.008987, Improvement found, counter reset to 0: 100%|██████████| 7/7 [00:00<00:00, 1003.87it/s]\n",
      "Epoch: 16, tloss: 0.009036796167492867, vloss: 0.016889, No improvement in the last 1 epochs: 100%|██████████| 7/7 [00:00<00:00, 1007.90it/s]\n",
      "Epoch: 17, tloss: 0.009504576213657856, vloss: 0.014087, No improvement in the last 2 epochs: 100%|██████████| 7/7 [00:00<00:00, 1043.69it/s]\n",
      "Epoch: 18, tloss: 0.05779397487640381, vloss: 0.012317, No improvement in the last 3 epochs: 100%|██████████| 7/7 [00:00<00:00, 857.93it/s]\n",
      "Epoch: 19, tloss: 0.001863109297119081, vloss: 0.012097, No improvement in the last 4 epochs: 100%|██████████| 7/7 [00:00<00:00, 866.67it/s]\n",
      "Epoch: 20, tloss: 0.010492893867194653, vloss: 0.011086, Early stopping triggered after 5 epochs.: 100%|██████████| 7/7 [00:00<00:00, 936.95it/s]\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import tqdm\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import LabelEncoder, StandardScaler\n",
    "from torch import nn\n",
    "from torch.autograd import Variable\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "# Set random seed for reproducibility\n",
    "np.random.seed(42)\n",
    "torch.manual_seed(42)\n",
    "\n",
    "def load_data():\n",
    "    df = pd.read_csv(\n",
    "        \"https://data.heatonresearch.com/data/t81-558/iris.csv\", na_values=[\"NA\", \"?\"]\n",
    "    )\n",
    "\n",
    "    le = LabelEncoder()\n",
    "\n",
    "    x = df[[\"sepal_l\", \"sepal_w\", \"petal_l\", \"petal_w\"]].values\n",
    "    y = le.fit_transform(df[\"species\"])\n",
    "    species = le.classes_\n",
    "\n",
    "    # Split into validation and training sets\n",
    "    x_train, x_test, y_train, y_test = train_test_split(\n",
    "        x, y, test_size=0.25, random_state=42\n",
    "    )\n",
    "\n",
    "    scaler = StandardScaler()\n",
    "    x_train = scaler.fit_transform(x_train)\n",
    "    x_test = scaler.transform(x_test)\n",
    "\n",
    "    # Numpy to Torch Tensor\n",
    "    x_train = torch.tensor(x_train, device=device, dtype=torch.float32)\n",
    "    y_train = torch.tensor(y_train, device=device, dtype=torch.long)\n",
    "\n",
    "    x_test = torch.tensor(x_test, device=device, dtype=torch.float32)\n",
    "    y_test = torch.tensor(y_test, device=device, dtype=torch.long)\n",
    "\n",
    "    return x_train, x_test, y_train, y_test, species\n",
    "\n",
    "\n",
    "x_train, x_test, y_train, y_test, species = load_data()\n",
    "\n",
    "# Create datasets\n",
    "BATCH_SIZE = 16\n",
    "\n",
    "dataset_train = TensorDataset(x_train, y_train)\n",
    "dataloader_train = DataLoader(\n",
    "    dataset_train, batch_size=BATCH_SIZE, shuffle=True)\n",
    "\n",
    "dataset_test = TensorDataset(x_test, y_test)\n",
    "dataloader_test = DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=True)\n",
    "\n",
    "# Create model using nn.Sequential\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(x_train.shape[1], 50),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(50, 25),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(25, len(species)),\n",
    "    nn.LogSoftmax(dim=1),\n",
    ")\n",
    "\n",
    "model = torch.compile(model,backend=\"aot_eager\").to(device)\n",
    "\n",
    "loss_fn = nn.CrossEntropyLoss()  # cross entropy loss\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n",
    "es = EarlyStopping()\n",
    "\n",
    "epoch = 0\n",
    "done = False\n",
    "while epoch < 1000 and not done:\n",
    "    epoch += 1\n",
    "    steps = list(enumerate(dataloader_train))\n",
    "    pbar = tqdm.tqdm(steps)\n",
    "    model.train()\n",
    "    for i, (x_batch, y_batch) in pbar:\n",
    "        y_batch_pred = model(x_batch.to(device))\n",
    "        loss = loss_fn(y_batch_pred, y_batch.to(device))\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        loss, current = loss.item(), (i + 1) * len(x_batch)\n",
    "        if i == len(steps) - 1:\n",
    "            model.eval()\n",
    "            pred = model(x_test)\n",
    "            vloss = loss_fn(pred, y_test)\n",
    "            if es(model, vloss):\n",
    "                done = True\n",
    "            pbar.set_description(\n",
    "                f\"Epoch: {epoch}, tloss: {loss}, vloss: {vloss:>7f}, {es.status}\"\n",
    "            )\n",
    "        else:\n",
    "            pbar.set_description(f\"Epoch: {epoch}, tloss {loss:}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "aJCDY-FcP41U",
    "outputId": "901b0a97-32cd-443e-d89c-4c47985c633b"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss = 0.008987346664071083\n"
     ]
    }
   ],
   "source": [
    "pred = model(x_test)\n",
    "vloss = loss_fn(pred, y_test)\n",
    "print(f\"Loss = {vloss}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ATJhTzRjJlHQ"
   },
   "source": [
    "As the output above shows, training did not run for the full 1,000 epochs that the loop allows. It stopped at epoch 20, after the validation loss failed to improve for five consecutive epochs, and the weights from the best epoch (epoch 15) were restored.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "A0iNHDxNJlHR",
    "outputId": "9c60711a-d28b-4360-92d8-2205fe0d7fcc"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 1.0\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "pred = model(x_test)\n",
    "_, predict_classes = torch.max(pred, 1)\n",
    "correct = accuracy_score(y_test.cpu(), predict_classes.cpu())\n",
    "print(f\"Accuracy: {correct}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "nR03ea5QJlHS"
   },
   "source": [
    "### Early Stopping with Regression\n",
    "\n",
    "The following code demonstrates how we can apply early stopping to a regression problem. The technique is similar to the early stopping for classification code that we just saw.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "pTuEcZE4JlHS",
    "outputId": "0dffc406-fb3b-41ad-bb34-e3f9f942e725"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch: 1, tloss: 261.34014892578125, vloss: 672.386719, EStop:[]: 100%|██████████| 19/19 [00:00<00:00, 95.23it/s] \n",
      "Epoch: 2, tloss: 189.8587646484375, vloss: 192.848831, EStop:[Improvement found, counter reset to 0]: 100%|██████████| 19/19 [00:00<00:00, 608.19it/s]\n",
      "Epoch: 3, tloss: 206.1345672607422, vloss: 180.863358, EStop:[Improvement found, counter reset to 0]: 100%|██████████| 19/19 [00:00<00:00, 362.06it/s]\n",
      "Epoch: 4, tloss: 189.10604858398438, vloss: 174.722351, EStop:[Improvement found, counter reset to 0]: 100%|██████████| 19/19 [00:00<00:00, 710.34it/s]\n",
      "Epoch: 5, tloss: 200.2677459716797, vloss: 170.232437, EStop:[Improvement found, counter reset to 0]: 100%|██████████| 19/19 [00:00<00:00, 726.50it/s]\n",
      "Epoch: 6, tloss: 180.11233520507812, vloss: 165.325562, EStop:[Improvement found, counter reset to 0]: 100%|██████████| 19/19 [00:00<00:00, 631.19it/s]\n",
      "Epoch: 7, tloss: 122.46978759765625, vloss: 152.626938, EStop:[Improvement found, counter reset to 0]: 100%|██████████| 19/19 [00:00<00:00, 641.74it/s]\n",
      "Epoch: 8, tloss: 142.17471313476562, vloss: 144.341217, EStop:[Improvement found, counter reset to 0]: 100%|██████████| 19/19 [00:00<00:00, 643.18it/s]\n",
      "Epoch: 9, tloss: 67.58553314208984, vloss: 133.457550, EStop:[Improvement found, counter reset to 0]: 100%|██████████| 19/19 [00:00<00:00, 777.28it/s]\n",
      "Epoch: 10, tloss: 75.73152923583984, vloss: 118.177582, EStop:[Improvement found, counter reset to 0]: 100%|██████████| 19/19 [00:00<00:00, 756.43it/s]\n",
      "Epoch: 11, tloss: 92.48855590820312, vloss: 126.596443, EStop:[No improvement in the last 1 epochs]: 100%|██████████| 19/19 [00:00<00:00, 569.22it/s]\n",
      "Epoch: 12, tloss: 73.30276489257812, vloss: 86.829872, EStop:[Improvement found, counter reset to 0]: 100%|██████████| 19/19 [00:00<00:00, 441.27it/s]\n",
      "Epoch: 13, tloss: 100.01966857910156, vloss: 69.598763, EStop:[Improvement found, counter reset to 0]: 100%|██████████| 19/19 [00:00<00:00, 629.51it/s]\n",
      "Epoch: 14, tloss: 62.98884201049805, vloss: 55.573502, EStop:[Improvement found, counter reset to 0]: 100%|██████████| 19/19 [00:00<00:00, 870.96it/s]\n",
      "Epoch: 15, tloss: 42.85931396484375, vloss: 43.179443, EStop:[Improvement found, counter reset to 0]: 100%|██████████| 19/19 [00:00<00:00, 944.70it/s]\n",
      "Epoch: 16, tloss: 40.424705505371094, vloss: 50.241440, EStop:[No improvement in the last 1 epochs]: 100%|██████████| 19/19 [00:00<00:00, 307.35it/s]\n",
      "Epoch: 17, tloss: 5.7408671379089355, vloss: 54.786167, EStop:[No improvement in the last 2 epochs]: 100%|██████████| 19/19 [00:00<00:00, 1024.26it/s]\n",
      "Epoch: 18, tloss: 37.34910583496094, vloss: 31.465496, EStop:[Improvement found, counter reset to 0]: 100%|██████████| 19/19 [00:00<00:00, 691.34it/s]\n",
      "Epoch: 19, tloss: 47.07156753540039, vloss: 60.890282, EStop:[No improvement in the last 1 epochs]: 100%|██████████| 19/19 [00:00<00:00, 667.89it/s]\n",
      "Epoch: 20, tloss: 17.795011520385742, vloss: 50.140453, EStop:[No improvement in the last 2 epochs]: 100%|██████████| 19/19 [00:00<00:00, 678.33it/s]\n",
      "Epoch: 21, tloss: 26.97421646118164, vloss: 27.666336, EStop:[Improvement found, counter reset to 0]: 100%|██████████| 19/19 [00:00<00:00, 943.39it/s]\n",
      "Epoch: 22, tloss: 32.7357292175293, vloss: 24.210747, EStop:[Improvement found, counter reset to 0]: 100%|██████████| 19/19 [00:00<00:00, 604.10it/s]\n",
      "Epoch: 23, tloss: 26.935253143310547, vloss: 23.672142, EStop:[Improvement found, counter reset to 0]: 100%|██████████| 19/19 [00:00<00:00, 783.03it/s]\n",
      "Epoch: 24, tloss: 44.157379150390625, vloss: 21.878456, EStop:[Improvement found, counter reset to 0]: 100%|██████████| 19/19 [00:00<00:00, 306.38it/s]\n",
      "Epoch: 25, tloss: 34.703861236572266, vloss: 21.751738, EStop:[Improvement found, counter reset to 0]: 100%|██████████| 19/19 [00:00<00:00, 930.37it/s]\n",
      "Epoch: 26, tloss: 18.24307632446289, vloss: 20.961008, EStop:[Improvement found, counter reset to 0]: 100%|██████████| 19/19 [00:00<00:00, 1076.80it/s]\n",
      "Epoch: 27, tloss: 17.406925201416016, vloss: 18.943779, EStop:[Improvement found, counter reset to 0]: 100%|██████████| 19/19 [00:00<00:00, 357.26it/s]\n",
      "Epoch: 28, tloss: 13.008171081542969, vloss: 23.957783, EStop:[No improvement in the last 1 epochs]: 100%|██████████| 19/19 [00:00<00:00, 721.64it/s]\n",
      "Epoch: 29, tloss: 33.604469299316406, vloss: 16.431408, EStop:[Improvement found, counter reset to 0]: 100%|██████████| 19/19 [00:00<00:00, 996.33it/s]\n",
      "Epoch: 30, tloss: 33.36851501464844, vloss: 16.989378, EStop:[No improvement in the last 1 epochs]: 100%|██████████| 19/19 [00:00<00:00, 356.00it/s]\n",
      "Epoch: 31, tloss: 20.354137420654297, vloss: 23.444510, EStop:[No improvement in the last 2 epochs]: 100%|██████████| 19/19 [00:00<00:00, 639.43it/s]\n",
      "Epoch: 32, tloss: 7.3451247215271, vloss: 14.537480, EStop:[Improvement found, counter reset to 0]: 100%|██████████| 19/19 [00:00<00:00, 1087.70it/s]\n",
      "Epoch: 33, tloss: 15.939814567565918, vloss: 17.934065, EStop:[No improvement in the last 1 epochs]: 100%|██████████| 19/19 [00:00<00:00, 1024.04it/s]\n",
      "Epoch: 34, tloss: 18.71426010131836, vloss: 13.295027, EStop:[Improvement found, counter reset to 0]: 100%|██████████| 19/19 [00:00<00:00, 339.66it/s]\n",
      "Epoch: 35, tloss: 7.033115386962891, vloss: 12.081243, EStop:[Improvement found, counter reset to 0]: 100%|██████████| 19/19 [00:00<00:00, 768.36it/s]\n",
      "Epoch: 36, tloss: 8.607196807861328, vloss: 11.788025, EStop:[Improvement found, counter reset to 0]: 100%|██████████| 19/19 [00:00<00:00, 976.40it/s]\n",
      "Epoch: 37, tloss: 2.295261859893799, vloss: 12.488452, EStop:[No improvement in the last 1 epochs]: 100%|██████████| 19/19 [00:00<00:00, 1044.58it/s]\n",
      "Epoch: 38, tloss: 13.847299575805664, vloss: 13.275513, EStop:[No improvement in the last 2 epochs]: 100%|██████████| 19/19 [00:00<00:00, 353.46it/s]\n",
      "Epoch: 39, tloss: 17.288166046142578, vloss: 11.650891, EStop:[Improvement found, counter reset to 0]: 100%|██████████| 19/19 [00:00<00:00, 627.89it/s]\n",
      "Epoch: 40, tloss: 20.6008358001709, vloss: 19.881641, EStop:[No improvement in the last 1 epochs]: 100%|██████████| 19/19 [00:00<00:00, 1009.02it/s]\n",
      "Epoch: 41, tloss: 20.57865333557129, vloss: 16.989870, EStop:[No improvement in the last 2 epochs]: 100%|██████████| 19/19 [00:00<00:00, 400.12it/s]\n",
      "Epoch: 42, tloss: 16.719619750976562, vloss: 10.719060, EStop:[Improvement found, counter reset to 0]: 100%|██████████| 19/19 [00:00<00:00, 697.43it/s]\n",
      "Epoch: 43, tloss: 7.9366607666015625, vloss: 9.597138, EStop:[Improvement found, counter reset to 0]: 100%|██████████| 19/19 [00:00<00:00, 990.26it/s]\n",
      "Epoch: 44, tloss: 29.034564971923828, vloss: 9.041054, EStop:[Improvement found, counter reset to 0]: 100%|██████████| 19/19 [00:00<00:00, 333.93it/s]\n",
      "Epoch: 45, tloss: 3.965550184249878, vloss: 10.987783, EStop:[No improvement in the last 1 epochs]: 100%|██████████| 19/19 [00:00<00:00, 784.19it/s]\n",
      "Epoch: 46, tloss: 7.56799840927124, vloss: 9.403233, EStop:[No improvement in the last 2 epochs]: 100%|██████████| 19/19 [00:00<00:00, 1031.66it/s]\n",
      "Epoch: 47, tloss: 9.794296264648438, vloss: 8.315845, EStop:[Improvement found, counter reset to 0]: 100%|██████████| 19/19 [00:00<00:00, 1042.98it/s]\n",
      "Epoch: 48, tloss: 18.71722412109375, vloss: 8.309561, EStop:[Improvement found, counter reset to 0]: 100%|██████████| 19/19 [00:00<00:00, 649.13it/s]\n",
      "Epoch: 49, tloss: 2.794205665588379, vloss: 10.559619, EStop:[No improvement in the last 1 epochs]: 100%|██████████| 19/19 [00:00<00:00, 512.50it/s]\n",
      "Epoch: 50, tloss: 2.9968934059143066, vloss: 8.354447, EStop:[No improvement in the last 2 epochs]: 100%|██████████| 19/19 [00:00<00:00, 1079.85it/s]\n",
      "Epoch: 51, tloss: 7.316162109375, vloss: 8.509170, EStop:[No improvement in the last 3 epochs]: 100%|██████████| 19/19 [00:00<00:00, 905.45it/s]\n",
      "Epoch: 52, tloss: 3.4442672729492188, vloss: 8.402485, EStop:[No improvement in the last 4 epochs]: 100%|██████████| 19/19 [00:00<00:00, 1058.69it/s]\n",
      "Epoch: 53, tloss: 17.530012130737305, vloss: 17.295694, EStop:[Early stopping triggered after 5 epochs.]: 100%|██████████| 19/19 [00:00<00:00, 252.36it/s]\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import tqdm\n",
    "from sklearn import preprocessing\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.model_selection import train_test_split\n",
    "from torch.autograd import Variable\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "# Read the MPG dataset.\n",
    "df = pd.read_csv(\n",
    "    \"https://data.heatonresearch.com/data/t81-558/auto-mpg.csv\", na_values=[\"NA\", \"?\"]\n",
    ")\n",
    "\n",
    "cars = df[\"name\"]\n",
    "\n",
    "# Handle missing value\n",
    "df[\"horsepower\"] = df[\"horsepower\"].fillna(df[\"horsepower\"].median())\n",
    "\n",
    "# Pandas to Numpy\n",
    "x = df[\n",
    "    [\n",
    "        \"cylinders\",\n",
    "        \"displacement\",\n",
    "        \"horsepower\",\n",
    "        \"weight\",\n",
    "        \"acceleration\",\n",
    "        \"year\",\n",
    "        \"origin\",\n",
    "    ]\n",
    "].values\n",
    "y = df[\"mpg\"].values  # regression\n",
    "\n",
    "# Split into validation and training sets\n",
    "x_train, x_test, y_train, y_test = train_test_split(\n",
    "    x, y, test_size=0.25, random_state=42\n",
    ")\n",
    "\n",
    "# Numpy to Torch Tensor\n",
    "x_train = torch.tensor(x_train, device=device, dtype=torch.float32)\n",
    "y_train = torch.tensor(y_train, device=device, dtype=torch.float32)\n",
    "\n",
    "x_test = torch.tensor(x_test, device=device, dtype=torch.float32)\n",
    "y_test = torch.tensor(y_test, device=device, dtype=torch.float32)\n",
    "\n",
    "\n",
    "# Create datasets\n",
    "BATCH_SIZE = 16\n",
    "\n",
    "dataset_train = TensorDataset(x_train, y_train)\n",
    "dataloader_train = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)\n",
    "\n",
    "dataset_test = TensorDataset(x_test, y_test)\n",
    "dataloader_test = DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=True)\n",
    "\n",
    "\n",
    "# Create model\n",
    "\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(x_train.shape[1], 50), \n",
    "    nn.ReLU(), \n",
    "    nn.Linear(50, 25), \n",
    "    nn.ReLU(), \n",
    "    nn.Linear(25, 1)\n",
    ")\n",
    "\n",
    "model = torch.compile(model, backend=\"aot_eager\").to(device)\n",
    "\n",
    "# Define the loss function for regression\n",
    "loss_fn = nn.MSELoss()\n",
    "\n",
    "# Define the optimizer\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n",
    "\n",
    "es = EarlyStopping()\n",
    "\n",
    "epoch = 0\n",
    "done = False\n",
    "while epoch < 1000 and not done:\n",
    "    epoch += 1\n",
    "    steps = list(enumerate(dataloader_train))\n",
    "    pbar = tqdm.tqdm(steps)\n",
    "    model.train()\n",
    "    for i, (x_batch, y_batch) in pbar:\n",
    "        y_batch_pred = model(x_batch).flatten()  #\n",
    "        loss = loss_fn(y_batch_pred, y_batch)\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        loss, current = loss.item(), (i + 1) * len(x_batch)\n",
    "        if i == len(steps) - 1:\n",
    "            model.eval()\n",
    "            pred = model(x_test).flatten()\n",
    "            vloss = loss_fn(pred, y_test)\n",
    "            if es(model, vloss):\n",
    "                done = True\n",
    "            pbar.set_description(\n",
    "                f\"Epoch: {epoch}, tloss: {loss}, vloss: {vloss:>7f}, EStop:[{es.status}]\"\n",
    "            )\n",
    "        else:\n",
    "            pbar.set_description(f\"Epoch: {epoch}, tloss {loss:}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wjvaHmp5JlHS"
   },
   "source": [
    "Finally, we evaluate the error.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "0bvqiX-AJlHS",
    "outputId": "b2cb2433-2bfd-4e28-ce4d-ba89d3b03d4d"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Final score (RMSE): 2.8826308250427246\n"
     ]
    }
   ],
   "source": [
    "from sklearn import metrics\n",
    "\n",
    "# Measure RMSE error.  RMSE is common for regression.\n",
    "pred = model(x_test)\n",
    "score = torch.sqrt(torch.nn.functional.mse_loss(pred.flatten(), y_test))\n",
    "print(f\"Final score (RMSE): {score}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "WYi-h2LNXqoZ",
    "outputId": "14669b13-b693-44d4-de6b-7150e9629302"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([33.0000, 28.0000, 19.0000, 13.0000, 14.0000, 27.0000, 24.0000, 13.0000,\n",
       "        17.0000, 21.0000, 15.0000, 38.0000, 26.0000, 15.0000, 25.0000, 12.0000,\n",
       "        31.0000, 17.0000, 16.0000, 31.0000, 22.0000, 22.0000, 22.0000, 33.5000,\n",
       "        18.0000, 44.0000, 26.0000, 24.5000, 18.1000, 12.0000, 27.0000, 36.0000,\n",
       "        23.0000, 24.0000, 37.2000, 16.0000, 21.0000, 19.2000, 16.0000, 29.0000,\n",
       "        26.8000, 27.0000, 18.0000, 10.0000, 23.0000, 36.0000, 26.0000, 25.0000,\n",
       "        25.0000, 25.0000, 22.0000, 34.1000, 32.4000, 13.0000, 23.5000, 14.0000,\n",
       "        18.5000, 29.8000, 28.0000, 19.0000, 11.0000, 33.0000, 23.0000, 21.0000,\n",
       "        23.0000, 25.0000, 23.8000, 34.4000, 24.5000, 13.0000, 34.7000, 14.0000,\n",
       "        15.0000, 18.0000, 25.0000, 19.9000, 17.5000, 28.0000, 29.0000, 17.0000,\n",
       "        16.0000, 27.0000, 37.0000, 36.1000, 23.0000, 14.0000, 32.8000, 29.9000,\n",
       "        20.0000, 12.0000, 15.5000, 23.7000, 24.0000, 36.0000, 19.0000, 38.0000,\n",
       "        29.0000, 21.5000, 27.9000, 14.0000])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dqUztPo3JlHT"
   },
   "source": [
    "## Saving and Loading a PyTorch Neural Network\n",
    "\n",
    "Complex neural networks can take a long time to fit/train, so it is helpful to save them to disk and reload them later. A reloaded neural network does not require retraining. By default, `torch.save` serializes objects using Python's [pickle](https://wiki.python.org/moin/UsingPickle) format. The following code trains a neural network to predict car MPG and saves the entire model to the file `mpg.pkl`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0, loss: 743.8233032226562\n",
      "Epoch 100, loss: 14.934571266174316\n",
      "Epoch 200, loss: 10.297266006469727\n",
      "Epoch 300, loss: 8.763288497924805\n",
      "Epoch 400, loss: 6.330925941467285\n",
      "Epoch 500, loss: 5.86508846282959\n",
      "Epoch 600, loss: 5.130534648895264\n",
      "Epoch 700, loss: 5.078980922698975\n",
      "Epoch 800, loss: 6.33461332321167\n",
      "Epoch 900, loss: 5.402225971221924\n",
      "Before save score (RMSE): 1.842262864112854\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "# For reproducibility\n",
    "torch.manual_seed(0)\n",
    "np.random.seed(0)\n",
    "\n",
    "# Read the MPG dataset.\n",
    "df = pd.read_csv(\n",
    "    \"https://data.heatonresearch.com/data/t81-558/auto-mpg.csv\", na_values=[\"NA\", \"?\"]\n",
    ")\n",
    "\n",
    "# Handle missing value\n",
    "df[\"horsepower\"] = df[\"horsepower\"].fillna(df[\"horsepower\"].median())\n",
    "\n",
    "# Select features and target\n",
    "features = df[\n",
    "    [\n",
    "        \"cylinders\",\n",
    "        \"displacement\",\n",
    "        \"horsepower\",\n",
    "        \"weight\",\n",
    "        \"acceleration\",\n",
    "        \"year\",\n",
    "        \"origin\",\n",
    "    ]\n",
    "]\n",
    "target = df[\"mpg\"]\n",
    "\n",
    "# Normalize the features\n",
    "scaler = StandardScaler()\n",
    "scaled_features = scaler.fit_transform(features)\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Convert Numpy to PyTorch tensors\n",
    "features_tensor = torch.tensor(\n",
    "    scaled_features, device=device, dtype=torch.float32)\n",
    "target_tensor = torch.tensor(target.values, device=device, dtype=torch.float32)\n",
    "\n",
    "# Convert to TensorDataset\n",
    "dataset = TensorDataset(features_tensor, target_tensor)\n",
    "\n",
    "# Convert to DataLoader\n",
    "data_loader = DataLoader(dataset, batch_size=32)\n",
    "\n",
    "# Define the neural network using nn.Sequential\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(features_tensor.shape[1], 50),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(50, 25),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(25, 1),\n",
    ").to(device)\n",
    "\n",
    "# Define the loss function for regression\n",
    "loss_fn = nn.MSELoss()\n",
    "\n",
    "# Define the optimizer\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n",
    "\n",
    "# Train for 1000 epochs.\n",
    "model.train()\n",
    "for epoch in range(1000):\n",
    "    for batch_features, batch_target in data_loader:\n",
    "        optimizer.zero_grad()\n",
    "        out = model(batch_features).flatten()\n",
    "        loss = loss_fn(out, batch_target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    # Display status every 100 epochs.\n",
    "    if epoch % 100 == 0:\n",
    "        print(f\"Epoch {epoch}, loss: {loss.item()}\")\n",
    "\n",
    "model.eval()\n",
    "pred = model(features_tensor)\n",
    "\n",
    "# Measure RMSE error.  RMSE is common for regression.\n",
    "score = torch.sqrt(torch.nn.functional.mse_loss(pred.flatten(), target_tensor))\n",
    "print(f\"Before save score (RMSE): {score}\")\n",
    "torch.save(model, \"mpg.pkl\")"
   ]
  },
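  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Saving the whole model object, as above, pickles the model's structure along with its weights. A commonly used alternative is to save only the `state_dict` (the learned parameters) and load it into a freshly built network with the same architecture. The following is a minimal sketch under stated assumptions: the `build_model` helper, the random inputs, and the file name `mpg_state.pt` are hypothetical stand-ins and are not part of the code in this notebook.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "def build_model():\n",
    "    # Hypothetical helper: same layer layout as the MPG network above\n",
    "    # (7 input features), rebuilt from scratch each time it is called.\n",
    "    return nn.Sequential(\n",
    "        nn.Linear(7, 50), nn.ReLU(), nn.Linear(50, 25), nn.ReLU(), nn.Linear(25, 1)\n",
    "    )\n",
    "\n",
    "\n",
    "torch.manual_seed(0)\n",
    "original = build_model()\n",
    "x = torch.randn(4, 7)  # made-up inputs for illustration\n",
    "\n",
    "# Save only the parameters, not the pickled model object.\n",
    "path = os.path.join(tempfile.gettempdir(), \"mpg_state.pt\")\n",
    "torch.save(original.state_dict(), path)\n",
    "\n",
    "# Rebuild the architecture, then copy the saved parameters into it.\n",
    "restored = build_model()\n",
    "restored.load_state_dict(torch.load(path))\n",
    "original.eval()\n",
    "restored.eval()\n",
    "\n",
    "with torch.no_grad():\n",
    "    same = torch.equal(original(x), restored(x))\n",
    "print(f\"Predictions match after reload: {same}\")\n",
    "```\n",
    "\n",
    "Because only tensors are stored, this form does not require unpickling a full model object, but the loading code must rebuild an identical architecture before calling `load_state_dict`.\n"
   ]
  },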
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The code below does not rebuild or refit the neural network. Instead, it reloads the entire model that the previous cell saved to `mpg.pkl` and performs another prediction on the same data. The RMSE should match the previous one exactly if we saved and reloaded the neural network correctly.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
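    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After load score (RMSE): 1.842262864112854\n"
     ]
    }
   ],
   "source": [
    "# Reload the saved network from disk.\n",
    "# weights_only=False is needed on recent PyTorch releases to unpickle a full\n",
    "# model object; only load pickle files that you trust.\n",
    "model = torch.load(\"mpg.pkl\", weights_only=False)\n",
    "\n",
    "# Measure RMSE error for loaded network.  RMSE is common for regression.\n",
    "model.eval()\n",
    "pred = model(features_tensor)\n",
    "score = torch.sqrt(torch.nn.functional.mse_loss(pred.flatten(), target_tensor))\n",
    "print(f\"After load score (RMSE): {score}\")"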
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "anaconda-cloud": {},
  "colab": {
   "collapsed_sections": [],
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3.9 (torch)",
   "language": "python",
   "name": "pytorch"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
